#pragma once

#include <cuda.h>
#include <kfusion/cuda/kernel_containers.hpp>

#define FULL_MASK 0xffffffff

namespace kfusion {
namespace device {
template <class T>
__kf_hdevice__ void swap(T &a, T &b) {
    T c(a);
    a = b;
    b = c;
}

template <typename T>
struct numeric_limits;

template <>
struct numeric_limits<float> {
    __kf_device__ static float quiet_NaN() { return __int_as_float(0x7fffffff); /*CUDART_NAN_F*/ };
    __kf_device__ static float epsilon() { return 1.192092896e-07f /*FLT_EPSILON*/; };
    __kf_device__ static float min() { return 1.175494351e-38f /*FLT_MIN*/; };
    __kf_device__ static float max() { return 3.402823466e+38f /*FLT_MAX*/; };
};

template <>
struct numeric_limits<unsigned short> {
    __kf_device__ static unsigned short max() { return USHRT_MAX; };
};

__kf_device__ float dot(const float3 &v1, const float3 &v2) {
    return __fmaf_rn(v1.x, v2.x, __fmaf_rn(v1.y, v2.y, v1.z * v2.z));
}

__kf_device__ float3 &operator+=(float3 &vec, const float &v) {
    vec.x += v;
    vec.y += v;
    vec.z += v;
    return vec;
}

__kf_device__ float3 &operator+=(float3 &v1, const float3 &v2) {
    v1.x += v2.x;
    v1.y += v2.y;
    v1.z += v2.z;
    return v1;
}

__kf_device__ float3 operator+(const float3 &v1, const float3 &v2) {
    return make_float3(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z);
}

__kf_device__ float3 operator*(const float3 &v1, const float3 &v2) {
    return make_float3(v1.x * v2.x, v1.y * v2.y, v1.z * v2.z);
}

__kf_hdevice__ float3 operator*(const float3 &v1, const int3 &v2) {
    return make_float3(v1.x * v2.x, v1.y * v2.y, v1.z * v2.z);
}

__kf_device__ float3 operator/(const float3 &v1, const float3 &v2) {
    return make_float3(v1.x / v2.x, v1.y / v2.y, v1.z / v2.z);
}

__kf_hdevice__ float3 operator/(const float &v, const float3 &vec) {
    return make_float3(v / vec.x, v / vec.y, v / vec.z);
}

__kf_device__ float3 &operator*=(float3 &vec, const float &v) {
    vec.x *= v;
    vec.y *= v;
    vec.z *= v;
    return vec;
}

__kf_hdevice__ float3 operator-(const float3 &v1, const float3 &v2) {
    return make_float3(v1.x - v2.x, v1.y - v2.y, v1.z - v2.z);
}

__kf_hdevice__ float3 operator*(const float3 &v1, const float &v) { return make_float3(v1.x * v, v1.y * v, v1.z * v); }

__kf_hdevice__ float3 operator*(const float &v, const float3 &v1) { return make_float3(v1.x * v, v1.y * v, v1.z * v); }

__kf_device__ float norm(const float3 &v) { return sqrt(dot(v, v)); }

__kf_device__ float norm_sqr(const float3 &v) { return dot(v, v); }

__kf_device__ float3 normalized(const float3 &v) { return v * rsqrt(dot(v, v)); }

__kf_hdevice__ float3 cross(const float3 &v1, const float3 &v2) {
    return make_float3(v1.y * v2.z - v1.z * v2.y, v1.z * v2.x - v1.x * v2.z, v1.x * v2.y - v1.y * v2.x);
}

__kf_device__ void computeRoots2(const float &b, const float &c, float3 &roots) {
    roots.x = 0.f;
    float d = b * b - 4.f * c;
    if (d < 0.f)  // no real roots!!!! THIS SHOULD NOT HAPPEN!
        d = 0.f;

    float sd = sqrtf(d);

    roots.z = 0.5f * (b + sd);
    roots.y = 0.5f * (b - sd);
}

__kf_device__ void computeRoots3(float c0, float c1, float c2, float3 &roots) {
    if (fabsf(c0) < numeric_limits<float>::epsilon())  // one root is 0 -> quadratic equation
    {
        computeRoots2(c2, c1, roots);
    } else {
        const float s_inv3  = 1.f / 3.f;
        const float s_sqrt3 = sqrtf(3.f);
        // Construct the parameters used in classifying the roots of the equation
        // and in solving the equation for the roots in closed form.
        float c2_over_3 = c2 * s_inv3;
        float a_over_3  = (c1 - c2 * c2_over_3) * s_inv3;
        if (a_over_3 > 0.f)
            a_over_3 = 0.f;

        float half_b = 0.5f * (c0 + c2_over_3 * (2.f * c2_over_3 * c2_over_3 - c1));

        float q = half_b * half_b + a_over_3 * a_over_3 * a_over_3;
        if (q > 0.f)
            q = 0.f;

        // Compute the eigenvalues by solving for the roots of the polynomial.
        float rho       = sqrtf(-a_over_3);
        float theta     = atan2f(sqrtf(-q), half_b) * s_inv3;
        float cos_theta = __cosf(theta);
        float sin_theta = __sinf(theta);
        roots.x         = c2_over_3 + 2.f * rho * cos_theta;
        roots.y         = c2_over_3 - rho * (cos_theta + s_sqrt3 * sin_theta);
        roots.z         = c2_over_3 - rho * (cos_theta - s_sqrt3 * sin_theta);

        // Sort in increasing order.
        if (roots.x >= roots.y)
            swap(roots.x, roots.y);

        if (roots.y >= roots.z) {
            swap(roots.y, roots.z);

            if (roots.x >= roots.y)
                swap(roots.x, roots.y);
        }
        if (roots.x <= 0)  // eigenval for symetric positive semi-definite matrix can
                           // not be negative! Set it to 0
            computeRoots2(c2, c1, roots);
    }
}

struct Eigen33 {
public:
    template <int Rows>
    struct MiniMat {
        float3 data[Rows];
        __kf_hdevice__ float3 &operator[](int i) { return data[i]; }
        __kf_hdevice__ const float3 &operator[](int i) const { return data[i]; }
    };
    typedef MiniMat<3> Mat33;
    typedef MiniMat<4> Mat43;

    static __kf_device__ float3 unitOrthogonal(const float3 &src) {
        float3 perp;
        /* Let us compute the crossed product of *this with a vector
         * that is not too close to being colinear to *this.
         */

        /* unless the x and y coords are both close to zero, we can
         * simply take ( -y, x, 0 ) and normalize it.
         */
        if (!isMuchSmallerThan(src.x, src.z) || !isMuchSmallerThan(src.y, src.z)) {
            float invnm = rsqrtf(src.x * src.x + src.y * src.y);
            perp.x      = -src.y * invnm;
            perp.y      = src.x * invnm;
            perp.z      = 0.0f;
        }
        /* if both x and y are close to zero, then the vector is close
         * to the z-axis, so it's far from colinear to the x-axis for instance.
         * So we take the crossed product with (1,0,0) and normalize it.
         */
        else {
            float invnm = rsqrtf(src.z * src.z + src.y * src.y);
            perp.x      = 0.0f;
            perp.y      = -src.z * invnm;
            perp.z      = src.y * invnm;
        }

        return perp;
    }

    __kf_device__ Eigen33(volatile float *mat_pkg_arg) : mat_pkg(mat_pkg_arg) {}
    __kf_device__ void compute(Mat33 &tmp, Mat33 &vec_tmp, Mat33 &evecs, float3 &evals) {
        // Scale the matrix so its entries are in [-1,1].  The scaling is applied
        // only when at least one matrix entry has magnitude larger than 1.

        float max01 = fmaxf(fabsf(mat_pkg[0]), fabsf(mat_pkg[1]));
        float max23 = fmaxf(fabsf(mat_pkg[2]), fabsf(mat_pkg[3]));
        float max45 = fmaxf(fabsf(mat_pkg[4]), fabsf(mat_pkg[5]));
        float m0123 = fmaxf(max01, max23);
        float scale = fmaxf(max45, m0123);

        if (scale <= numeric_limits<float>::min())
            scale = 1.f;

        mat_pkg[0] /= scale;
        mat_pkg[1] /= scale;
        mat_pkg[2] /= scale;
        mat_pkg[3] /= scale;
        mat_pkg[4] /= scale;
        mat_pkg[5] /= scale;

        // The characteristic equation is x^3 - c2*x^2 + c1*x - c0 = 0.  The
        // eigenvalues are the roots to this equation, all guaranteed to be
        // real-valued, because the matrix is symmetric.
        float c0 = m00() * m11() * m22() + 2.f * m01() * m02() * m12() - m00() * m12() * m12() - m11() * m02() * m02() -
                   m22() * m01() * m01();
        float c1 = m00() * m11() - m01() * m01() + m00() * m22() - m02() * m02() + m11() * m22() - m12() * m12();
        float c2 = m00() + m11() + m22();

        computeRoots3(c0, c1, c2, evals);

        if (evals.z - evals.x <= numeric_limits<float>::epsilon()) {
            evecs[0] = make_float3(1.f, 0.f, 0.f);
            evecs[1] = make_float3(0.f, 1.f, 0.f);
            evecs[2] = make_float3(0.f, 0.f, 1.f);
        } else if (evals.y - evals.x <= numeric_limits<float>::epsilon()) {
            // first and second equal
            tmp[0] = row0();
            tmp[1] = row1();
            tmp[2] = row2();
            tmp[0].x -= evals.z;
            tmp[1].y -= evals.z;
            tmp[2].z -= evals.z;

            vec_tmp[0] = cross(tmp[0], tmp[1]);
            vec_tmp[1] = cross(tmp[0], tmp[2]);
            vec_tmp[2] = cross(tmp[1], tmp[2]);

            float len1 = dot(vec_tmp[0], vec_tmp[0]);
            float len2 = dot(vec_tmp[1], vec_tmp[1]);
            float len3 = dot(vec_tmp[2], vec_tmp[2]);

            if (len1 >= len2 && len1 >= len3) {
                evecs[2] = vec_tmp[0] * rsqrtf(len1);
            } else if (len2 >= len1 && len2 >= len3) {
                evecs[2] = vec_tmp[1] * rsqrtf(len2);
            } else {
                evecs[2] = vec_tmp[2] * rsqrtf(len3);
            }

            evecs[1] = unitOrthogonal(evecs[2]);
            evecs[0] = cross(evecs[1], evecs[2]);
        } else if (evals.z - evals.y <= numeric_limits<float>::epsilon()) {
            // second and third equal
            tmp[0] = row0();
            tmp[1] = row1();
            tmp[2] = row2();
            tmp[0].x -= evals.x;
            tmp[1].y -= evals.x;
            tmp[2].z -= evals.x;

            vec_tmp[0] = cross(tmp[0], tmp[1]);
            vec_tmp[1] = cross(tmp[0], tmp[2]);
            vec_tmp[2] = cross(tmp[1], tmp[2]);

            float len1 = dot(vec_tmp[0], vec_tmp[0]);
            float len2 = dot(vec_tmp[1], vec_tmp[1]);
            float len3 = dot(vec_tmp[2], vec_tmp[2]);

            if (len1 >= len2 && len1 >= len3) {
                evecs[0] = vec_tmp[0] * rsqrtf(len1);
            } else if (len2 >= len1 && len2 >= len3) {
                evecs[0] = vec_tmp[1] * rsqrtf(len2);
            } else {
                evecs[0] = vec_tmp[2] * rsqrtf(len3);
            }

            evecs[1] = unitOrthogonal(evecs[0]);
            evecs[2] = cross(evecs[0], evecs[1]);
        } else {
            tmp[0] = row0();
            tmp[1] = row1();
            tmp[2] = row2();
            tmp[0].x -= evals.z;
            tmp[1].y -= evals.z;
            tmp[2].z -= evals.z;

            vec_tmp[0] = cross(tmp[0], tmp[1]);
            vec_tmp[1] = cross(tmp[0], tmp[2]);
            vec_tmp[2] = cross(tmp[1], tmp[2]);

            float len1 = dot(vec_tmp[0], vec_tmp[0]);
            float len2 = dot(vec_tmp[1], vec_tmp[1]);
            float len3 = dot(vec_tmp[2], vec_tmp[2]);

            float mmax[3];

            unsigned int min_el = 2;
            unsigned int max_el = 2;
            if (len1 >= len2 && len1 >= len3) {
                mmax[2]  = len1;
                evecs[2] = vec_tmp[0] * rsqrtf(len1);
            } else if (len2 >= len1 && len2 >= len3) {
                mmax[2]  = len2;
                evecs[2] = vec_tmp[1] * rsqrtf(len2);
            } else {
                mmax[2]  = len3;
                evecs[2] = vec_tmp[2] * rsqrtf(len3);
            }

            tmp[0] = row0();
            tmp[1] = row1();
            tmp[2] = row2();
            tmp[0].x -= evals.y;
            tmp[1].y -= evals.y;
            tmp[2].z -= evals.y;

            vec_tmp[0] = cross(tmp[0], tmp[1]);
            vec_tmp[1] = cross(tmp[0], tmp[2]);
            vec_tmp[2] = cross(tmp[1], tmp[2]);

            len1 = dot(vec_tmp[0], vec_tmp[0]);
            len2 = dot(vec_tmp[1], vec_tmp[1]);
            len3 = dot(vec_tmp[2], vec_tmp[2]);

            if (len1 >= len2 && len1 >= len3) {
                mmax[1]  = len1;
                evecs[1] = vec_tmp[0] * rsqrtf(len1);
                min_el   = len1 <= mmax[min_el] ? 1 : min_el;
                max_el   = len1 > mmax[max_el] ? 1 : max_el;
            } else if (len2 >= len1 && len2 >= len3) {
                mmax[1]  = len2;
                evecs[1] = vec_tmp[1] * rsqrtf(len2);
                min_el   = len2 <= mmax[min_el] ? 1 : min_el;
                max_el   = len2 > mmax[max_el] ? 1 : max_el;
            } else {
                mmax[1]  = len3;
                evecs[1] = vec_tmp[2] * rsqrtf(len3);
                min_el   = len3 <= mmax[min_el] ? 1 : min_el;
                max_el   = len3 > mmax[max_el] ? 1 : max_el;
            }

            tmp[0] = row0();
            tmp[1] = row1();
            tmp[2] = row2();
            tmp[0].x -= evals.x;
            tmp[1].y -= evals.x;
            tmp[2].z -= evals.x;

            vec_tmp[0] = cross(tmp[0], tmp[1]);
            vec_tmp[1] = cross(tmp[0], tmp[2]);
            vec_tmp[2] = cross(tmp[1], tmp[2]);

            len1 = dot(vec_tmp[0], vec_tmp[0]);
            len2 = dot(vec_tmp[1], vec_tmp[1]);
            len3 = dot(vec_tmp[2], vec_tmp[2]);

            if (len1 >= len2 && len1 >= len3) {
                mmax[0]  = len1;
                evecs[0] = vec_tmp[0] * rsqrtf(len1);
                min_el   = len3 <= mmax[min_el] ? 0 : min_el;
                max_el   = len3 > mmax[max_el] ? 0 : max_el;
            } else if (len2 >= len1 && len2 >= len3) {
                mmax[0]  = len2;
                evecs[0] = vec_tmp[1] * rsqrtf(len2);
                min_el   = len3 <= mmax[min_el] ? 0 : min_el;
                max_el   = len3 > mmax[max_el] ? 0 : max_el;
            } else {
                mmax[0]  = len3;
                evecs[0] = vec_tmp[2] * rsqrtf(len3);
                min_el   = len3 <= mmax[min_el] ? 0 : min_el;
                max_el   = len3 > mmax[max_el] ? 0 : max_el;
            }

            unsigned mid_el = 3 - min_el - max_el;
            evecs[min_el]   = normalized(cross(evecs[(min_el + 1) % 3], evecs[(min_el + 2) % 3]));
            evecs[mid_el]   = normalized(cross(evecs[(mid_el + 1) % 3], evecs[(mid_el + 2) % 3]));
        }
        // Rescale back to the original size.
        evals *= scale;
    }

private:
    volatile float *mat_pkg;

    __kf_device__ float m00() const { return mat_pkg[0]; }
    __kf_device__ float m01() const { return mat_pkg[1]; }
    __kf_device__ float m02() const { return mat_pkg[2]; }
    __kf_device__ float m10() const { return mat_pkg[1]; }
    __kf_device__ float m11() const { return mat_pkg[3]; }
    __kf_device__ float m12() const { return mat_pkg[4]; }
    __kf_device__ float m20() const { return mat_pkg[2]; }
    __kf_device__ float m21() const { return mat_pkg[4]; }
    __kf_device__ float m22() const { return mat_pkg[5]; }

    __kf_device__ float3 row0() const { return make_float3(m00(), m01(), m02()); }
    __kf_device__ float3 row1() const { return make_float3(m10(), m11(), m12()); }
    __kf_device__ float3 row2() const { return make_float3(m20(), m21(), m22()); }

    __kf_device__ static bool isMuchSmallerThan(float x, float y) {
        // copied from <eigen>/include/Eigen/src/Core/NumTraits.h
        const float prec_sqr = numeric_limits<float>::epsilon() * numeric_limits<float>::epsilon();
        return x * x <= prec_sqr * y * y;
    }
};

struct Warp {
    enum { LOG_WARP_SIZE = 5, WARP_SIZE = 1 << LOG_WARP_SIZE, STRIDE = WARP_SIZE };

    /** \brief Returns the warp lane ID of the calling thread. */
    static __kf_device__ unsigned int laneId() {
        unsigned int ret;
        asm("mov.u32 %0, %laneid;" : "=r"(ret));
        return ret;
    }

    static __kf_device__ unsigned int id() {
        int tid = threadIdx.z * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
        return tid >> LOG_WARP_SIZE;
    }

    static __kf_device__ int laneMaskLt() {
#if (__CUDA_ARCH__ >= 200)
        unsigned int ret;
        asm("mov.u32 %0, %lanemask_lt;" : "=r"(ret));
        return ret;
#else
        return 0xFFFFFFFF >> (32 - laneId());
#endif
    }
    static __kf_device__ int binaryExclScan(int ballot_mask) { return __popc(Warp::laneMaskLt() & ballot_mask); }
};

struct Block {
    static __kf_device__ unsigned int stride() { return blockDim.x * blockDim.y * blockDim.z; }

    static __kf_device__ int flattenedThreadId() {
        return threadIdx.z * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
    }

    template <int CTA_SIZE, typename T, class BinOp>
    static __kf_device__ void reduce(volatile T *buffer, BinOp op) {
        int tid = flattenedThreadId();
        T val   = buffer[tid];

        if (CTA_SIZE >= 1024) {
            if (tid < 512)
                buffer[tid] = val = op(val, buffer[tid + 512]);
            __syncthreads();
        }
        if (CTA_SIZE >= 512) {
            if (tid < 256)
                buffer[tid] = val = op(val, buffer[tid + 256]);
            __syncthreads();
        }
        if (CTA_SIZE >= 256) {
            if (tid < 128)
                buffer[tid] = val = op(val, buffer[tid + 128]);
            __syncthreads();
        }
        if (CTA_SIZE >= 128) {
            if (tid < 64)
                buffer[tid] = val = op(val, buffer[tid + 64]);
            __syncthreads();
        }

        if (tid < 32) {
            if (CTA_SIZE >= 64) {
                buffer[tid] = val = op(val, buffer[tid + 32]);
            }
            if (CTA_SIZE >= 32) {
                buffer[tid] = val = op(val, buffer[tid + 16]);
            }
            if (CTA_SIZE >= 16) {
                buffer[tid] = val = op(val, buffer[tid + 8]);
            }
            if (CTA_SIZE >= 8) {
                buffer[tid] = val = op(val, buffer[tid + 4]);
            }
            if (CTA_SIZE >= 4) {
                buffer[tid] = val = op(val, buffer[tid + 2]);
            }
            if (CTA_SIZE >= 2) {
                buffer[tid] = val = op(val, buffer[tid + 1]);
            }
        }
    }

    template <int CTA_SIZE, typename T, class BinOp>
    static __kf_device__ T reduce(volatile T *buffer, T init, BinOp op) {
        int tid = flattenedThreadId();
        T val = buffer[tid] = init;
        __syncthreads();

        if (CTA_SIZE >= 1024) {
            if (tid < 512)
                buffer[tid] = val = op(val, buffer[tid + 512]);
            __syncthreads();
        }
        if (CTA_SIZE >= 512) {
            if (tid < 256)
                buffer[tid] = val = op(val, buffer[tid + 256]);
            __syncthreads();
        }
        if (CTA_SIZE >= 256) {
            if (tid < 128)
                buffer[tid] = val = op(val, buffer[tid + 128]);
            __syncthreads();
        }
        if (CTA_SIZE >= 128) {
            if (tid < 64)
                buffer[tid] = val = op(val, buffer[tid + 64]);
            __syncthreads();
        }

        if (tid < 32) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 300 && 0
            if (CTA_SIZE >= 64)
                val = op(val, buffer[tid + 32]);
            if (CTA_SIZE >= 32)
                val = op(val, __shfl_xor(val, 16));
            if (CTA_SIZE >= 16)
                val = op(val, __shfl_xor(val, 8));
            if (CTA_SIZE >= 8)
                val = op(val, __shfl_xor(val, 4));
            if (CTA_SIZE >= 4)
                val = op(val, __shfl_xor(val, 2));
            if (CTA_SIZE >= 2)
                buffer[tid] = op(val, __shfl_xor(val, 1));
#else
            if (CTA_SIZE >= 64) {
                buffer[tid] = val = op(val, buffer[tid + 32]);
            }
            if (CTA_SIZE >= 32) {
                buffer[tid] = val = op(val, buffer[tid + 16]);
            }
            if (CTA_SIZE >= 16) {
                buffer[tid] = val = op(val, buffer[tid + 8]);
            }
            if (CTA_SIZE >= 8) {
                buffer[tid] = val = op(val, buffer[tid + 4]);
            }
            if (CTA_SIZE >= 4) {
                buffer[tid] = val = op(val, buffer[tid + 2]);
            }
            if (CTA_SIZE >= 2) {
                buffer[tid] = val = op(val, buffer[tid + 1]);
            }
#endif
        }
        __syncthreads();
        return buffer[0];
    }
};

struct Emulation {
    static __kf_device__ int warp_reduce(volatile int *ptr, const unsigned int tid) {
        const unsigned int lane = tid & 31;  // index of thread in warp (0..31)

        if (lane < 16) {
            int partial = ptr[tid];

            ptr[tid] = partial = partial + ptr[tid + 16];
            ptr[tid] = partial = partial + ptr[tid + 8];
            ptr[tid] = partial = partial + ptr[tid + 4];
            ptr[tid] = partial = partial + ptr[tid + 2];
            ptr[tid] = partial = partial + ptr[tid + 1];
        }
        return ptr[tid - lane];
    }

    static __kf_device__ int Ballot(int predicate, volatile int *cta_buffer) {
#if __CUDA_ARCH__ >= 200
        (void) cta_buffer;
        return __ballot_sync(FULL_MASK, predicate);
#else
        int tid         = Block::flattenedThreadId();
        cta_buffer[tid] = predicate ? (1 << (tid & 31)) : 0;
        return warp_reduce(cta_buffer, tid);
#endif
    }

    static __kf_device__ bool All(int predicate, volatile int *cta_buffer) {
#if __CUDA_ARCH__ >= 200
        (void) cta_buffer;
        return __all_sync(FULL_MASK, predicate);
#else
        int tid         = Block::flattenedThreadId();
        cta_buffer[tid] = predicate ? 1 : 0;
        return warp_reduce(cta_buffer, tid) == 32;
#endif
    }
};
}  // namespace device
}  // namespace kfusion
